from torchvision.models import resnet18


def get_net(arch: str = "resnet18", num_classes: int = 10):
    """Construct a classification network selected by name.

    Args:
        arch: Architecture name. Only ``"resnet18"`` is currently supported.
        num_classes: Number of output classes, forwarded unchanged to
            torchvision's ``resnet18`` constructor.

    Returns:
        A new torchvision ``resnet18`` model. No ``weights`` argument is
        passed, so torchvision's constructor defaults apply.

    Raises:
        ValueError: If ``arch`` is not a supported architecture name.
    """
    if arch == "resnet18":
        return resnet18(num_classes=num_classes)
    # Fail loudly: silently returning None for an unknown name only surfaces
    # later as an unrelated-looking error at the call site.
    raise ValueError(f"Unsupported arch {arch!r}; expected one of: 'resnet18'")
